{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 支持向量机"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----\n",
    "分离超平面：$w^Tx+b=0$\n",
    "\n",
    "点到直线距离：$r=\\frac{|w^Tx+b|}{||w||_2}$\n",
    "\n",
    "$||w||_2$为2-范数：$||w||_2=\\sqrt[2]{\\sum^m_{i=1}w_i^2}$\n",
    "\n",
    "直线为超平面，样本可表示为：\n",
    "\n",
    "$w^Tx+b\\ \\geq+1$\n",
    "\n",
    "$w^Tx+b\\ \\leq+1$\n",
    "\n",
    "#### margin：\n",
    "\n",
    "**函数间隔**：$label(w^Tx+b)\\ or\\ y_i(w^Tx+b)$\n",
    "\n",
    "**几何间隔**：$r=\\frac{label(w^Tx+b)}{||w||_2}$，当数据被正确分类时，几何间隔就是点到超平面的距离\n",
    "\n",
    "为了求几何间隔最大，SVM基本问题可以转化为求解:($\\frac{r^*}{||w||}$为几何间隔，(${r^*}$为函数间隔)\n",
    "\n",
    "$$\\max\\ \\frac{r^*}{||w||}$$\n",
    "\n",
    "$$(subject\\ to)\\ y_i({w^T}x_i+{b})\\geq {r^*},\\ i=1,2,..,m$$\n",
    "\n",
    "分类点几何间隔最大，同时被正确分类。但这个方程并非凸函数求解，所以要先①将方程转化为凸函数，②用拉格朗日乘子法和KKT条件求解对偶问题。\n",
    "\n",
    "①转化为凸函数：\n",
    "\n",
    "先令${r^*}=1$，方便计算（参照衡量，不影响评价结果）\n",
    "\n",
    "$$\\max\\ \\frac{1}{||w||}$$\n",
    "\n",
    "$$s.t.\\ y_i({w^T}x_i+{b})\\geq {1},\\ i=1,2,..,m$$\n",
    "\n",
    "再将$\\max\\ \\frac{1}{||w||}$转化成$\\min\\ \\frac{1}{2}||w||^2$求解凸函数，1/2是为了求导之后方便计算。\n",
    "\n",
    "$$\\min\\ \\frac{1}{2}||w||^2$$\n",
    "\n",
    "$$s.t.\\ y_i(w^Tx_i+b)\\geq 1,\\ i=1,2,..,m$$\n",
    "\n",
    "②用拉格朗日乘子法和KKT条件求解最优值：\n",
    "\n",
    "$$\\min\\ \\frac{1}{2}||w||^2$$\n",
    "\n",
    "$$s.t.\\ -y_i(w^Tx_i+b)+1\\leq 0,\\ i=1,2,..,m$$\n",
    "\n",
    "整合成：\n",
    "\n",
    "$$L(w, b, \\alpha) = \\frac{1}{2}||w||^2+\\sum^m_{i=1}\\alpha_i(-y_i(w^Tx_i+b)+1)$$\n",
    "\n",
    "推导：$\\min\\ f(x)=\\min \\max\\ L(w, b, \\alpha)\\geq \\max \\min\\ L(w, b, \\alpha)$\n",
    "\n",
    "根据KKT条件：\n",
    "\n",
    "$$\\frac{\\partial }{\\partial w}L(w, b, \\alpha)=w-\\sum\\alpha_iy_ix_i=0,\\ w=\\sum\\alpha_iy_ix_i$$\n",
    "\n",
    "$$\\frac{\\partial }{\\partial b}L(w, b, \\alpha)=\\sum\\alpha_iy_i=0$$\n",
    "\n",
    "带入$ L(w, b, \\alpha)$\n",
    "\n",
    "$\\min\\  L(w, b, \\alpha)=\\frac{1}{2}||w||^2+\\sum^m_{i=1}\\alpha_i(-y_i(w^Tx_i+b)+1)$\n",
    "\n",
    "$\\qquad\\qquad\\qquad=\\frac{1}{2}w^Tw-\\sum^m_{i=1}\\alpha_iy_iw^Tx_i-b\\sum^m_{i=1}\\alpha_iy_i+\\sum^m_{i=1}\\alpha_i$\n",
    "\n",
    "$\\qquad\\qquad\\qquad=\\frac{1}{2}w^T\\sum\\alpha_iy_ix_i-\\sum^m_{i=1}\\alpha_iy_iw^Tx_i+\\sum^m_{i=1}\\alpha_i$\n",
    "\n",
    "$\\qquad\\qquad\\qquad=\\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i=1}\\alpha_iy_iw^Tx_i$\n",
    "\n",
    "$\\qquad\\qquad\\qquad=\\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)$\n",
    "\n",
    "再把max问题转成min问题：\n",
    "\n",
    "$\\max\\ \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)=\\min \\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)-\\sum^m_{i=1}\\alpha_i$\n",
    "\n",
    "$s.t.\\ \\sum^m_{i=1}\\alpha_iy_i=0,$\n",
    "\n",
    "$ \\alpha_i \\geq 0,i=1,2,...,m$\n",
    "\n",
    "以上为SVM对偶问题的对偶形式\n",
    "\n",
    "-----\n",
    "#### kernel\n",
    "\n",
    "在低维空间计算获得高维空间的计算结果，也就是说计算结果满足高维（满足高维，才能说明高维下线性可分）。\n",
    "\n",
    "#### soft margin & slack variable\n",
    "\n",
    "引入松弛变量$\\xi\\geq0$，对应数据点允许偏离的functional margin 的量。\n",
    "\n",
    "目标函数：$\\min\\ \\frac{1}{2}||w||^2+C\\sum\\xi_i\\qquad s.t.\\ y_i(w^Tx_i+b)\\geq1-\\xi_i$ \n",
    "\n",
    "对偶问题：\n",
    "\n",
    "$$\\max\\ \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)=\\min \\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)-\\sum^m_{i=1}\\alpha_i$$\n",
    "\n",
    "$$s.t.\\ C\\geq\\alpha_i \\geq 0,i=1,2,...,m\\quad \\sum^m_{i=1}\\alpha_iy_i=0,$$\n",
    "\n",
    "-----\n",
    "\n",
    "#### Sequential Minimal Optimization\n",
    "\n",
    "首先定义特征到结果的输出函数：$u=w^Tx+b$.\n",
    "\n",
    "因为$w=\\sum\\alpha_iy_ix_i$\n",
    "\n",
    "有$u=\\sum y_i\\alpha_iK(x_i, x)-b$\n",
    "\n",
    "\n",
    "----\n",
    "\n",
    "$\\max \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i=1}\\sum^m_{j=1}\\alpha_i\\alpha_jy_iy_j<\\phi(x_i)^T,\\phi(x_j)>$\n",
    "\n",
    "$s.t.\\ \\sum^m_{i=1}\\alpha_iy_i=0,$\n",
    "\n",
    "$ \\alpha_i \\geq 0,i=1,2,...,m$\n",
    "\n",
    "-----\n",
    "参考资料：\n",
    "\n",
    "[1] :[Lagrange Multiplier and KKT](http://blog.csdn.net/xianlingmao/article/details/7919597)\n",
    "\n",
    "[2] :[推导SVM](https://my.oschina.net/dfsj66011/blog/517766)\n",
    "\n",
    "[3] :[机器学习算法实践-支持向量机(SVM)算法原理](http://pytlab.org/2017/08/15/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95%E5%AE%9E%E8%B7%B5-%E6%94%AF%E6%8C%81%E5%90%91%E9%87%8F%E6%9C%BA-SVM-%E7%AE%97%E6%B3%95%E5%8E%9F%E7%90%86/)\n",
    "\n",
    "[4] :[Python实现SVM](http://blog.csdn.net/wds2006sdo/article/details/53156589)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E:\\Anaconda3\\lib\\site-packages\\sklearn\\cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.cross_validation import train_test_split\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# data\n",
    "def create_data():\n",
    "    iris = load_iris()\n",
    "    df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
    "    df['label'] = iris.target\n",
    "    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n",
    "    data = np.array(df.iloc[:100, [0, 1, -1]])\n",
    "    for i in range(len(data)):\n",
    "        if data[i,-1] == 0:\n",
    "            data[i,-1] = -1\n",
    "    # print(data)\n",
    "    return data[:,:2], data[:,-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "X, y = create_data()\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x2c52a378be0>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGZ9JREFUeJzt3X9sHOWdx/H394yv8bWAReMWsJNLCihqSSLSugTICXGg\nXkqaQoRQlAiKQhE5ELpS0aNqKtQfqBJISLRQdEQBdBTBBeVoGihHgjgoKkUklZMg5y4pKhxtY8MV\nE5TQHKYE93t/7Dqx12vvzu6O93me/bwky97Zyfj7zMA3m5nPPGPujoiIpOWvml2AiIg0npq7iEiC\n1NxFRBKk5i4ikiA1dxGRBKm5i4gkSM1dRCRBau4iIglScxcRSdBx1a5oZm1AHzDo7stL3rsAeBx4\nvbhos7vfOtX2Zs6c6XPmzMlUrIhIq9u5c+fb7t5Vab2qmztwI7APOGGS918obfpTmTNnDn19fRl+\nvYiImNnvq1mvqtMyZtYDfAm4v56iRERkelR7zv1HwDeBv0yxznlm1m9mW83szHIrmNlaM+szs76h\noaGstYqISJUqNnczWw685e47p1htFzDb3RcCPwa2lFvJ3Te4e6+793Z1VTxlJCIiNarmnPsS4BIz\nWwbMAE4ws4fd/crRFdz93TE/P2Vm/2JmM9397caXLCJSnyNHjjAwMMD777/f7FImNWPGDHp6emhv\nb6/pz1ds7u6+DlgHR1Mx/zy2sReXnwz80d3dzM6m8C+CAzVVJCKSs4GBAY4//njmzJmDmTW7nAnc\nnQMHDjAwMMDcuXNr2kaWtMw4ZnZdsYj1wOXA9Wb2ITAMrHI9BUREAvX+++8H29gBzIyPf/zj1HNt\nMlNzd/fngeeLP68fs/we4J6aqxAJ2Jbdg9zx9Cu8cXCYUzs7uHnpPFYs6m52WVKnUBv7qHrrq/mT\nu0gr2LJ7kHWb9zB8ZASAwYPDrNu8B0ANXoKm6QdEpnDH068cbeyjho+McMfTrzSpIknFtm3bmDdv\nHqeffjq33357w7ev5i4yhTcODmdaLlKNkZERbrjhBrZu3crevXvZuHEje/fubejv0GkZkSmc2tnB\nYJlGfmpnRxOqkWZp9HWXX//615x++ul86lOfAmDVqlU8/vjjfOYzn2lUyfrkLjKVm5fOo6O9bdyy\njvY2bl46r0kVyXQbve4yeHAY59h1ly27B2ve5uDgILNmzTr6uqenh8HB2rdXjpq7yBRWLOrmtssW\n0N3ZgQHdnR3cdtkCXUxtIbFed9FpGZEKVizqVjNvYXlcd+nu7mb//v1HXw8MDNDd3dj/xvTJXURk\nCpNdX6nnusvnP/95fvvb3/L666/zwQcf8Oijj3LJJZfUvL1y1NxFRKaQx3WX4447jnvuuYelS5fy\n6U9/mpUrV3LmmWUn0639dzR0ayIiiRk9Jdfou5SXLVvGsmXLGlFiWWruIiIVxHjdRadlREQSpOYu\nIpIgNXcRkQSpuYuIJEjNXUQkQWrukowtuwdZcvtzzP3Wf7Dk9ufqmvtDJG9f/epX+cQnPsH8+fNz\n2b6auyQhj8mdRPK0Zs0atm3bltv21dwlCbFO7iSR6N8EP5wP3+ssfO/fVPcmzz//fE466aQGFFee\nbmKSJOihGpKb/k3w86/BkeJ/S4f2F14DLFzZvLoq0Cd3SUIekzuJAPDsrcca+6gjw4XlAVNzlyTo\noRqSm0MD2ZYHQqdlJAl5Te4kwok9hVMx5ZYHTM1dkhHj5E4SgYu+M/6cO0B7R2F5HVavXs3zzz/P\n22+/TU9PD9///ve55ppr6iz2GDV3qVujHx4sEpTRi6bP3lo4FXNiT6Gx13kxdePGjQ0obnJq7lKX\n0Xz5aAxxNF8OqMFLOhauDDoZU44uqEpdlC8XCZOau9RF+XKJlbs3u4Qp1VufmrvURflyidGMGTM4\ncOBAsA3e3Tlw4AAzZsyoeRs65y51uXnpvHHn3EH5cglfT08PAwMDDA0NNbuUSc2YMYOentrjlmru\nUhflyyVG7e3tzJ07t9ll5Krq5m5mbUAfMOjuy0veM+AuYBnwHrDG3Xc1slAJl/LlIuHJ8sn9RmAf\ncEKZ9y4Gzih+LQbuLX4XaSnK/EsoqrqgamY9wJeA+ydZ5VLgIS/YDnSa2SkNqlEkCppTXkJSbVrm\nR8A3gb9M8n43MHbyhYHiMpGWocy/hKRiczez5cBb7r6z3l9mZmvNrM/M+kK+Si1SC2X+JSTVfHJf\nAlxiZr8DHgUuNLOHS9YZBGaNed1TXDaOu29w91537+3q6qqxZJEwKfMvIanY3N19nbv3uPscYBXw\nnLtfWbLaE8BVVnAOcMjd32x8uSLh0pzyEpKac+5mdh2Au68HnqIQg3yVQhTy6oZUJxIRZf4lJNas\n2297e3u9r6+vKb9bRCRWZrbT3Xsrrac7VCVYt2zZw8Yd+xlxp82M1Ytn8YMVC5pdlkgU1NwlSLds\n2cPD2/9w9PWI+9HXavAilWlWSAnSxh1lnlk5xXIRGU/NXYI0Msm1oMmWi8h4au4SpDazTMtFZDw1\ndwnS6sWzMi0XkfF0QVWCNHrRVGkZkdoo5y4iEhHl3KUuV9z3Ei++9s7R10tOO4lHrj23iRU1j+Zo\nlxjpnLtMUNrYAV587R2uuO+lJlXUPJqjXWKl5i4TlDb2SstTpjnaJVZq7iJT0BztEis1d5EpaI52\niZWau0yw5LSTMi1PmeZol1ipucsEj1x77oRG3qppmRWLurntsgV0d3ZgQHdnB7ddtkBpGQmecu4i\nIhFRzl3qkle2O8t2lS8XqZ2au0wwmu0ejQCOZruBupprlu3mVYNIq9A5d5kgr2x3lu0qXy5SHzV3\nmSCvbHeW7SpfLlIfNXeZIK9sd5btKl8uUh81d5kgr2x3lu0qXy5SH11QlQlGL1g2OqmSZbt51SDS\nKpRzFxGJiHLuOYsxgx1jzSJSGzX3GsSYwY6xZhGpnS6o1iDGDHaMNYtI7dTcaxBjBjvGmkWkdmru\nNYgxgx1jzSJSOzX3GsSYwY6xZhGpnS6o1iDGDHaMNYtI7Srm3M1sBvBL4CMU/jJ4zN2/W7LOBcDj\nwOvFRZvd/daptqucu4hIdo3Muf8ZuNDdD5tZO/ArM9vq7ttL1nvB3ZfXUqxMj1u27GHjjv2MuNNm\nxurFs/jBigV1rxtKfj6UOkRCULG5e+Gj/eHiy/biV3Nua5Wa3bJlDw9v/8PR1yPuR1+XNu0s64aS\nnw+lDpFQVHVB1czazOxl4C3gGXffUWa188ys38y2mtmZDa1S6rZxx/6ql2dZN5T8fCh1iISiqubu\n7iPufhbQA5xtZvNLVtkFzHb3hcCPgS3ltmNma82sz8z6hoaG6qlbMhqZ5NpKueVZ1g0lPx9KHSKh\nyBSFdPeDwC+AL5Ysf9fdDxd/fgpoN7OZZf78Bnfvdfferq6uOsqWrNrMql6eZd1Q8vOh1CESiorN\n3cy6zKyz+HMH8AXgNyXrnGxW+D/fzM4ubvdA48uVWq1ePKvq5VnWDSU/H0odIqGoJi1zCvATM2uj\n0LQ3ufuTZnYdgLuvBy4HrjezD4FhYJU3ay5hKWv0Qmg1CZgs64aSnw+lDpFQaD53EZGIaD73nOWV\nqc6SL89z21nGF+O+iE7/Jnj2Vjg0ACf2wEXfgYUrm12VBEzNvQZ5Zaqz5Mvz3HaW8cW4L6LTvwl+\n/jU4Ukz+HNpfeA1q8DIpTRxWg7wy1Vny5XluO8v4YtwX0Xn21mONfdSR4cJykUmoudcgr0x1lnx5\nntvOMr4Y90V0Dg1kWy6CmntN8spUZ8mX57ntLOOLcV9E58SebMtFUHOvSV6Z6iz58jy3nWV8Me6L\n6Fz0HWgv+cuyvaOwXGQSuqBag7wy1Vny5XluO8v4YtwX0Rm9aKq0jGSgnLuISESUc5cJQsiuS+SU\nt4+GmnuLCCG7LpFT3j4quqDaIkLIrkvklLePipp7iwghuy6RU94+KmruLSKE7LpETnn7qKi5t4gQ\nsusSOeXto6ILqi0ihOy6RE55+6go5y4iEhHl3Ivyymtn2W4o85Irux6Y1DPjqY8viybsi6Sbe155\n7SzbDWVecmXXA5N6Zjz18WXRpH2R9AXVvPLaWbYbyrzkyq4HJvXMeOrjy6JJ+yLp5p5XXjvLdkOZ\nl1zZ9cCknhlPfXxZNGlfJN3c88prZ9luKPOSK7semNQz46mPL4sm7Yukm3teee0s2w1lXnJl1wOT\nemY89fFl0aR9kfQF1bzy2lm2G8q85MquByb1zHjq48uiSftCOXcRkYgo556zEPLzV9z3Ei++9s7R\n10tOO4lHrj237hpEkvLkTbDzQfARsDb43BpYfmf92w08x5/0Ofe8jGbGBw8O4xzLjG/ZPTht2y1t\n7AAvvvYOV9z3Ul01iCTlyZug74FCY4fC974HCsvrMZpdP7Qf8GPZ9f5NdZfcKGruNQghP1/a2Cst\nF2lJOx/MtrxaEeT41dxrEEJ+XkSq4CPZllcrghy/mnsNQsjPi0gVrC3b8mpFkONXc69BCPn5Jaed\nVHYbky0XaUmfW5NtebUiyPGruddgxaJubrtsAd2dHRjQ3dnBbZctaEh+vtrtPnLtuRMaudIyIiWW\n3wm91xz7pG5thdf1pmUWroQv3w0nzgKs8P3LdweVllHOXUQkIg3LuZvZDOCXwEeK6z/m7t8tWceA\nu4BlwHvAGnffVUvhlWTNl8c2h3mWud9T3xe55oizZJ/zqiPP8QWewa5L1rGlvC+mUM1NTH8GLnT3\nw2bWDvzKzLa6+/Yx61wMnFH8WgzcW/zeUFnnJI9tDvMsc7+nvi9ynQN7NPs8ajT7DBMbfF515Dm+\nlOdSzzq2lPdFBRXPuXvB4eLL9uJX6bmcS4GHiutuBzrN7JTGlpo9Xx7bHOZZ5n5PfV/kmiPOkn3O\nq448xxdBBrtmWceW8r6ooKoLqmbWZmYvA28Bz7j7jpJVuoGxHWiguKx0O2vNrM/M+oaGhjIXmzUH\nHltuPMvc76nvi1xzxFmyz3nVkef4Ishg1yzr2FLeFxVU1dzdfcTdzwJ6gLPNbH4tv8zdN7h7r7v3\ndnV1Zf7zWXPgseXGs8z9nvq+yDVHnCX7nFcdeY4vggx2zbKOLeV9UUGmKKS7HwR+AXyx5K1BYOwE\n5T3FZQ2VNV8e2xzmWeZ+T31f5JojzpJ9zquOPMcXQQa7ZlnHlvK+qKBiczezLjPrLP7cAXwB+E3J\nak8AV1nBOcAhd3+z0cVmzZfnlUfPyw9WLODKc2Yf/aTeZsaV58wum5ZJfV/kmiPOkn3Oq448xxdB\nBrtmWceW8r6ooGLO3cwWAj8B2ij8ZbDJ3W81s+sA3H19MQp5D4VP9O8BV7v7lCF25dxFRLJrWM7d\n3fuBRWWWrx/zswM3ZC1SRETykfzDOqK7cUemR5YbW0K4CSbPG3diu0krhOMRgaSbe3Q37sj0yHJj\nSwg3weR5405sN2mFcDwikfTEYdHduCPTI8uNLSHcBJPnjTux3aQVwvGIRNLNPbobd2R6ZLmxJYSb\nYPK8cSe2m7RCOB6RSLq5R3fjjkyPLDe2hHATTJ437sR2k1YIxyMSSTf36G7ckemR5caWEG6CyfPG\nndhu0grheEQi6eYe3Y07Mj2y3NgSwk0wed64E9tNWiEcj0joYR0iIhFp2E1MIi0vy4M9QhFbzaFk\n10OpowHU3EWmkuXBHqGIreZQsuuh1NEgSZ9zF6lblgd7hCK2mkPJrodSR4OouYtMJcuDPUIRW82h\nZNdDqaNB1NxFppLlwR6hiK3mULLrodTRIGruIlPJ8mCPUMRWcyjZ9VDqaBA1d5GpZHmwRyhiqzmU\n7HoodTSIcu4iIhFRzl2mT4zZ4LxqzitfHuM+lqZSc5f6xJgNzqvmvPLlMe5jaTqdc5f6xJgNzqvm\nvPLlMe5jaTo1d6lPjNngvGrOK18e4z6WplNzl/rEmA3Oq+a88uUx7mNpOjV3qU+M2eC8as4rXx7j\nPpamU3OX+sSYDc6r5rzy5THuY2k65dxFRCJSbc5dn9wlHf2b4Ifz4Xudhe/9m6Z/u3nVIJKRcu6S\nhryy4Fm2qzy6BESf3CUNeWXBs2xXeXQJiJq7pCGvLHiW7SqPLgFRc5c05JUFz7Jd5dElIGrukoa8\nsuBZtqs8ugREzV3SkFcWPMt2lUeXgFTMuZvZLOAh4JOAAxvc/a6SdS4AHgdeLy7a7O5TXkVSzl1E\nJLtGzuf+IfANd99lZscDO83sGXffW7LeC+6+vJZiJUAxzh+epeYYxxcC7bdoVGzu7v4m8Gbx5z+Z\n2T6gGyht7pKKGPPayqPnT/stKpnOuZvZHGARsKPM2+eZWb+ZbTWzMxtQmzRLjHlt5dHzp/0Wlarv\nUDWzjwE/Bb7u7u+WvL0LmO3uh81sGbAFOKPMNtYCawFmz55dc9GSsxjz2sqj50/7LSpVfXI3s3YK\njf0Rd99c+r67v+vuh4s/PwW0m9nMMuttcPded+/t6uqqs3TJTYx5beXR86f9FpWKzd3MDHgA2Ofu\nZecuNbOTi+thZmcXt3ugkYXKNIoxr608ev6036JSzWmZJcBXgD1m9nJx2beB2QDuvh64HLjezD4E\nhoFV3qy5hKV+oxfHYkpFZKk5xvGFQPstKprPXUQkIo3MuUuolDke78mbYOeDhQdSW1vh8Xb1PgVJ\nJFJq7rFS5ni8J2+CvgeOvfaRY6/V4KUFaW6ZWClzPN7OB7MtF0mcmnuslDkez0eyLRdJnJp7rJQ5\nHs/asi0XSZyae6yUOR7vc2uyLRdJnJp7rDR3+HjL74Tea459Ure2wmtdTJUWpZy7iEhElHOvwZbd\ng9zx9Cu8cXCYUzs7uHnpPFYs6m52WY2Tei4+9fGFQPs4GmruRVt2D7Ju8x6GjxTSFYMHh1m3eQ9A\nGg0+9Vx86uMLgfZxVHTOveiOp1852thHDR8Z4Y6nX2lSRQ2Wei4+9fGFQPs4KmruRW8cHM60PDqp\n5+JTH18ItI+jouZedGpnR6bl0Uk9F5/6+EKgfRwVNfeim5fOo6N9/A0vHe1t3Lx0XpMqarDUc/Gp\njy8E2sdR0QXVotGLpsmmZVKfizv18YVA+zgqyrmLiESk2py7TsuIxKB/E/xwPnyvs/C9f1Mc25am\n0WkZkdDlmS9Xdj1Z+uQuEro88+XKridLzV0kdHnmy5VdT5aau0jo8syXK7ueLDV3kdDlmS9Xdj1Z\nau4ioctz7n49FyBZyrmLiEREOXcRkRam5i4ikiA1dxGRBKm5i4gkSM1dRCRBau4iIglScxcRSZCa\nu4hIgio2dzObZWa/MLO9ZvbfZnZjmXXMzO42s1fNrN/MPptPuVIXzdst0jKqmc/9Q+Ab7r7LzI4H\ndprZM+6+d8w6FwNnFL8WA/cWv0soNG+3SEup+Mnd3d90913Fn/8E7ANKHyx6KfCQF2wHOs3slIZX\nK7XTvN0iLSXTOXczmwMsAnaUvNUN7B/zeoCJfwFgZmvNrM/M+oaGhrJVKvXRvN0iLaXq5m5mHwN+\nCnzd3d+t5Ze5+wZ373X33q6urlo2IbXSvN0iLaWq5m5m7RQa+yPuvrnMKoPArDGve4rLJBSat1uk\npVSTljHgAWCfu985yWpPAFcVUzPnAIfc/c0G1in10rzdIi2lmrTMEuArwB4ze7m47NvAbAB3Xw88\nBSwDXgXeA65ufKlSt4Ur1cxFWkTF5u7uvwKswjoO3NCookREpD66Q1VEJEFq7iIiCVJzFxFJkJq7\niEiC1NxFRBKk5i4ikiA1dxGRBFkhot6EX2w2BPy+Kb+8spnA280uIkcaX7xSHhtofNX4W3evODlX\n05p7yMysz917m11HXjS+eKU8NtD4GkmnZUREEqTmLiKSIDX38jY0u4CcaXzxSnlsoPE1jM65i4gk\nSJ/cRUQS1NLN3czazGy3mT1Z5r0LzOyQmb1c/IrqkUVm9jsz21Osva/M+2Zmd5vZq2bWb2afbUad\ntapifLEfv04ze8zMfmNm+8zs3JL3Yz9+lcYX7fEzs3lj6n7ZzN41s6+XrJP78avmYR0puxHYB5ww\nyfsvuPvyaayn0f7e3SfL1F4MnFH8WgzcW/wek6nGB3Efv7uAbe5+uZn9NfA3Je/HfvwqjQ8iPX7u\n/gpwFhQ+QFJ45OjPSlbL/fi17Cd3M+sBvgTc3+xamuRS4CEv2A50mtkpzS5KwMxOBM6n8HhL3P0D\ndz9Yslq0x6/K8aXiIuA1dy+9YTP349eyzR34EfBN4C9TrHNe8Z9MW83szGmqq1Ec+E8z22lma8u8\n3w3sH/N6oLgsFpXGB/Eev7nAEPCvxdOG95vZR0vWifn4VTM+iPf4jbUK2Fhmee7HryWbu5ktB95y\n951TrLYLmO3uC4EfA1umpbjG+Tt3P4vCP/9uMLPzm11Qg1UaX8zH7zjgs8C97r4I+D/gW80tqaGq\nGV/Mxw+A4ummS4B/b8bvb8nmTuGh35eY2e+AR4ELzezhsSu4+7vufrj481NAu5nNnPZKa+Tug8Xv\nb1E433d2ySqDwKwxr3uKy6JQaXyRH78BYMDddxRfP0ahGY4V8/GrOL7Ij9+oi4Fd7v7HMu/lfvxa\nsrm7+zp373H3ORT+2fScu185dh0zO9nMrPjz2RT21YFpL7YGZvZRMzt+9GfgH4D/KlntCeCq4lX7\nc4BD7v7mNJdak2rGF/Pxc/f/Bfab2bzioouAvSWrRXv8qhlfzMdvjNWUPyUD03D8Wj0tM46ZXQfg\n7uuBy4HrzexDYBhY5fHc8fVJ4GfF/zeOA/7N3beVjO8pYBnwKvAecHWTaq1FNeOL+fgB/BPwSPGf\n9v8DXJ3Q8YPK44v6+BU/dHwB+Mcxy6b1+OkOVRGRBLXkaRkRkdSpuYuIJEjNXUQkQWruIiIJUnMX\nEUmQmruISILU3EVEEqTmLiKSoP8H2fNC9uxjMHwAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x2c52a2dafd0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.scatter(X[:50,0],X[:50,1], label='0')\n",
    "plt.scatter(X[50:,0],X[50:,1], label='1')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class SVM:\n",
    "    def __init__(self, max_iter=100, kernel='linear'):\n",
    "        self.max_iter = max_iter\n",
    "        self._kernel = kernel\n",
    "    \n",
    "    def init_args(self, features, labels):\n",
    "        self.m, self.n = features.shape\n",
    "        self.X = features\n",
    "        self.Y = labels\n",
    "        self.b = 0.0\n",
    "        \n",
    "        # 将Ei保存在一个列表里\n",
    "        self.alpha = np.ones(self.m)\n",
    "        self.E = [self._E(i) for i in range(self.m)]\n",
    "        # 松弛变量\n",
    "        self.C = 1.0\n",
    "        \n",
    "    def _KKT(self, i):\n",
    "        y_g = self._g(i)*self.Y[i]\n",
    "        if self.alpha[i] == 0:\n",
    "            return y_g >= 1\n",
    "        elif 0 < self.alpha[i] < self.C:\n",
    "            return y_g == 1\n",
    "        else:\n",
    "            return y_g <= 1\n",
    "    \n",
    "    # g(x)预测值，输入xi（X[i]）\n",
    "    def _g(self, i):\n",
    "        r = self.b\n",
    "        for j in range(self.m):\n",
    "            r += self.alpha[j]*self.Y[j]*self.kernel(self.X[i], self.X[j])\n",
    "        return r\n",
    "    \n",
    "    # 核函数\n",
    "    def kernel(self, x1, x2):\n",
    "        if self._kernel == 'linear':\n",
    "            return sum([x1[k]*x2[k] for k in range(self.n)])\n",
    "        elif self._kernel == 'poly':\n",
    "            return (sum([x1[k]*x2[k] for k in range(self.n)]) + 1)**2\n",
    "    \n",
    "        return 0\n",
    "    \n",
    "    # E（x）为g(x)对输入x的预测值和y的差\n",
    "    def _E(self, i):\n",
    "        return self._g(i) - self.Y[i]\n",
    "    \n",
    "    def _init_alpha(self):\n",
    "        # 外层循环首先遍历所有满足0<a<C的样本点，检验是否满足KKT\n",
    "        index_list = [i for i in range(self.m) if 0 < self.alpha[i] < self.C]\n",
    "        # 否则遍历整个训练集\n",
    "        non_satisfy_list = [i for i in range(self.m) if i not in index_list]\n",
    "        index_list.extend(non_satisfy_list)\n",
    "        \n",
    "        for i in index_list:\n",
    "            if self._KKT(i):\n",
    "                continue\n",
    "            \n",
    "            E1 = self.E[i]\n",
    "            # 如果E2是+，选择最小的；如果E2是负的，选择最大的\n",
    "            if E1 >= 0:\n",
    "                j = min(range(self.m), key=lambda x: self.E[x])\n",
    "            else:\n",
    "                j = max(range(self.m), key=lambda x: self.E[x])\n",
    "            return i, j\n",
    "        \n",
    "    def _compare(self, _alpha, L, H):\n",
    "        if _alpha > H:\n",
    "            return H\n",
    "        elif _alpha < L:\n",
    "            return L\n",
    "        else:\n",
    "            return _alpha      \n",
    "    \n",
    "    def fit(self, features, labels):\n",
    "        self.init_args(features, labels)\n",
    "        \n",
    "        for t in range(self.max_iter):\n",
    "            # train\n",
    "            i1, i2 = self._init_alpha()\n",
    "            \n",
    "            # 边界\n",
    "            if self.Y[i1] == self.Y[i2]:\n",
    "                L = max(0, self.alpha[i1]+self.alpha[i2]-self.C)\n",
    "                H = min(self.C, self.alpha[i1]+self.alpha[i2])\n",
    "            else:\n",
    "                L = max(0, self.alpha[i2]-self.alpha[i1])\n",
    "                H = min(self.C, self.C+self.alpha[i2]-self.alpha[i1])\n",
    "                \n",
    "            E1 = self.E[i1]\n",
    "            E2 = self.E[i2]\n",
    "            # eta=K11+K22-2K12\n",
    "            eta = self.kernel(self.X[i1], self.X[i1]) + self.kernel(self.X[i2], self.X[i2]) - 2*self.kernel(self.X[i1], self.X[i2])\n",
    "            if eta <= 0:\n",
    "                # print('eta <= 0')\n",
    "                continue\n",
    "                \n",
    "            alpha2_new_unc = self.alpha[i2] + self.Y[i2] * (E2 - E1) / eta\n",
    "            alpha2_new = self._compare(alpha2_new_unc, L, H)\n",
    "            \n",
    "            alpha1_new = self.alpha[i1] + self.Y[i1] * self.Y[i2] * (self.alpha[i2] - alpha2_new)\n",
    "            \n",
    "            b1_new = -E1 - self.Y[i1] * self.kernel(self.X[i1], self.X[i1]) * (alpha1_new-self.alpha[i1]) - self.Y[i2] * self.kernel(self.X[i2], self.X[i1]) * (alpha2_new-self.alpha[i2])+ self.b \n",
    "            b2_new = -E2 - self.Y[i1] * self.kernel(self.X[i1], self.X[i2]) * (alpha1_new-self.alpha[i1]) - self.Y[i2] * self.kernel(self.X[i2], self.X[i2]) * (alpha2_new-self.alpha[i2])+ self.b \n",
    "            \n",
    "            if 0 < alpha1_new < self.C:\n",
    "                b_new = b1_new\n",
    "            elif 0 < alpha2_new < self.C:\n",
    "                b_new = b2_new\n",
    "            else:\n",
    "                # 选择中点\n",
    "                b_new = (b1_new + b2_new) / 2\n",
    "                \n",
    "            # 更新参数\n",
    "            self.alpha[i1] = alpha1_new\n",
    "            self.alpha[i2] = alpha2_new\n",
    "            self.b = b_new\n",
    "            \n",
    "            self.E[i1] = self._E(i1)\n",
    "            self.E[i2] = self._E(i2)\n",
    "        return 'train done!'\n",
    "            \n",
    "    def predict(self, data):\n",
    "        r = self.b\n",
    "        for i in range(self.m):\n",
    "            r += self.alpha[i] * self.Y[i] * self.kernel(data, self.X[i])\n",
    "            \n",
    "        return 1 if r > 0 else -1\n",
    "    \n",
    "    def score(self, X_test, y_test):\n",
    "        right_count = 0\n",
    "        for i in range(len(X_test)):\n",
    "            result = self.predict(X_test[i])\n",
    "            if result == y_test[i]:\n",
    "                right_count += 1\n",
    "        return right_count / len(X_test)\n",
    "    \n",
    "    def _weight(self):\n",
    "        # linear model\n",
    "        yx = self.Y.reshape(-1, 1)*self.X\n",
    "        self.w = np.dot(yx.T, self.alpha)\n",
    "        return self.w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "svm = SVM(max_iter=200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'train done!'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "svm.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.96"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "svm.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## sklearn.svm.SVC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n",
       "  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',\n",
       "  max_iter=-1, probability=False, random_state=None, shrinking=True,\n",
       "  tol=0.001, verbose=False)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.svm import SVC\n",
    "clf = SVC()\n",
    "clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### sklearn.svm.SVC\n",
    "\n",
    "*(C=1.0, kernel='rbf', degree=3, gamma='auto', coef0=0.0, shrinking=True, probability=False,tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=-1, decision_function_shape=None,random_state=None)*\n",
    "\n",
    "参数：\n",
    "\n",
    "- C：C-SVC的惩罚参数C?默认值是1.0\n",
    "\n",
    "C越大，相当于惩罚松弛变量，希望松弛变量接近0，即对误分类的惩罚增大，趋向于对训练集全分对的情况，这样对训练集测试时准确率很高，但泛化能力弱。C值小，对误分类的惩罚减小，允许容错，将他们当成噪声点，泛化能力较强。\n",
    "\n",
    "- kernel ：核函数，默认是rbf，可以是‘linear’, ‘poly’, ‘rbf’, ‘sigmoid’, ‘precomputed’ \n",
    "    \n",
    "    – 线性：u'v\n",
    "    \n",
    "    – 多项式：(gamma*u'*v + coef0)^degree\n",
    "\n",
    "    – RBF函数：exp(-gamma|u-v|^2)\n",
    "\n",
    "    – sigmoid：tanh(gamma*u'*v + coef0)\n",
    "\n",
    "\n",
    "- degree ：多项式poly函数的维度，默认是3，选择其他核函数时会被忽略。\n",
    "\n",
    "\n",
    "- gamma ： ‘rbf’,‘poly’ 和‘sigmoid’的核函数参数。默认是’auto’，则会选择1/n_features\n",
    "\n",
    "\n",
    "- coef0 ：核函数的常数项。对于‘poly’和 ‘sigmoid’有用。\n",
    "\n",
    "\n",
    "- probability ：是否采用概率估计？.默认为False\n",
    "\n",
    "\n",
    "- shrinking ：是否采用shrinking heuristic方法，默认为true\n",
    "\n",
    "\n",
    "- tol ：停止训练的误差值大小，默认为1e-3\n",
    "\n",
    "\n",
    "- cache_size ：核函数cache缓存大小，默认为200\n",
    "\n",
    "\n",
    "- class_weight ：类别的权重，字典形式传递。设置第几类的参数C为weight*C(C-SVC中的C)\n",
    "\n",
    "\n",
    "- verbose ：允许冗余输出？\n",
    "\n",
    "\n",
    "- max_iter ：最大迭代次数。-1为无限制。\n",
    "\n",
    "\n",
    "- decision_function_shape ：‘ovo’, ‘ovr’ or None, default=None3\n",
    "\n",
    "\n",
    "- random_state ：数据洗牌时的种子值，int值\n",
    "\n",
    "\n",
    "主要调节的参数有：C、kernel、degree、gamma、coef0。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
